from ril.util.net_factory import mlp_factory, dueling_factory
from ril.grid_world.rule import grid_world_rule
from ril.util.policy_factory import common_policy_factory

# Registry of network factories for each agent supported in the grid world.
# Every agent starts out mapped to the plain MLP factory; insertion order
# follows the tuple below (dict.fromkeys preserves it).
grid_world_net = dict.fromkeys(
    ('dqn', 'ddqn', 'dueling dqn', 'pg', 'discrete sac', 'a2c', 'ppo'),
    mlp_factory,
)
# The dueling variant instead uses a factory produced by dueling_factory.
# It receives two separate config dicts (each with one 64-unit hidden
# layer) and the 'mlp' backbone name.
# NOTE(review): the two dicts presumably configure the two dueling streams;
# confirm their roles against dueling_factory.
# Reassigning an existing key keeps its original position in the dict.
grid_world_net['dueling dqn'] = dueling_factory(
    {'hidden_sizes': [64]},
    {'hidden_sizes': [64]},
    'mlp',
)


def grid_world_policy(agent: str):
    """Return the policy object for ``agent`` under the grid-world rules.

    Delegates to ``common_policy_factory``, supplying ``grid_world_rule`` as
    the rule set, and returns whatever that factory produces.

    Args:
        agent: Agent identifier forwarded unchanged to the factory.
            NOTE(review): presumably one of the keys of ``grid_world_net``
            (e.g. ``'dqn'``, ``'ppo'``); confirm against
            ``common_policy_factory``.
    """
    rules = grid_world_rule
    return common_policy_factory(agent, rules)
